import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import sys
import os
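# Add the parent of this script's directory to sys.path (presumably so the project-local `optimizers` package imported below can be resolved).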
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import random
from torch.optim.lr_scheduler import StepLR
import argparse
import json

from optimizers.lamb import create_lamb_optimizer
from adabelief_pytorch import AdaBelief
from optimizers.ALTO import create_ALTO_optimizer

# Build the command-line interface. Arguments are parsed at module level, so
# they are read as soon as this script runs.
parser = argparse.ArgumentParser(description='BiLSTM for Named Entity Recognition')
# Optimizer name; compared against the supported names in the optimizer
# selection chain later in this script.
parser.add_argument('--optimizer', type=str, default='adam', help='Optimizers')
# Number of training epochs (becomes N_EPOCHS).
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs')
# Initial learning rate handed to whichever optimizer is selected.
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
# File name (inside the "log" directory) of the per-epoch metrics JSON; the
# test metrics are written next to it with a "testLog_" prefix.
parser.add_argument('--logname', type=str, default='log.json', help='Log file name')
# Batch size used by the train, validation and test DataLoaders (becomes BATCH_SIZE).
parser.add_argument('--batch-size', type=int, default=14987, help='batch size')
# NOTE(review): args.beta is not read anywhere else in the script as shown.
parser.add_argument('--beta', type=float, default=0.1, help='number of beta of optX')
args = parser.parse_args()


# Seed every random number generator the script relies on and ask cuDNN for
# deterministic kernels, so that repeated runs are comparable.
SEED = 42
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Report how many GPUs are visible; n_gpu is 0 when CUDA cannot be used.
if not torch.cuda.is_available():
    print("GPU is not available, using CPU.")
    n_gpu = 0
else:
    n_gpu = torch.cuda.device_count()
    print(f"{n_gpu} GPU(s) available.")

# Pick the device that the model and every batch will be moved to.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Load a CoNLL-style file and split it into sentences and their tag sequences.
def load_data(file_path):
    """Read a one-token-per-line, whitespace-separated file.

    Blank (or whitespace-only) lines separate sentences. On every other line
    the first column is taken as the token and the last column as its tag, so
    the 4-column CoNLL-2003 layout (token, POS, chunk, NER tag) as well as
    layouts with a different number of columns are accepted.

    Args:
        file_path: Path of the UTF-8 encoded text file.

    Returns:
        A ``(sentences, labels)`` tuple of two parallel lists; ``sentences[i]``
        is the list of tokens of sentence ``i`` and ``labels[i]`` the list of
        tag strings for those tokens.

    Raises:
        ValueError: If a non-blank line has fewer than two columns.
    """
    sentences, labels = [], []
    sentence, label = [], []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_no, line in enumerate(file, start=1):
            columns = line.split()
            if not columns:
                # Sentence boundary; consecutive blank lines must not create
                # empty sentences.
                if sentence:
                    sentences.append(sentence)
                    labels.append(label)
                    sentence, label = [], []
                continue
            if len(columns) < 2:
                raise ValueError(
                    f"{file_path}:{line_no}: expected at least 2 columns "
                    f"(token ... tag), got {line.strip()!r}"
                )
            sentence.append(columns[0])
            label.append(columns[-1])
    # The file may end without a trailing blank line.
    if sentence:
        sentences.append(sentence)
        labels.append(label)
    return sentences, labels

# Build a vocabulary that maps every distinct word to an integer index.
def build_vocab(sentences):
    """Return a word -> index dict for the given tokenised sentences.

    Index 0 is reserved for '<pad>' and index 1 for '<unk>'; every other word
    receives the next free index in order of first appearance.
    """
    vocab = {'<pad>': 0, '<unk>': 1}
    for token in (word for sentence in sentences for word in sentence):
        # setdefault only inserts unseen words, so len(vocab) is the next free index.
        vocab.setdefault(token, len(vocab))
    return vocab

# Convert sentences and tag sequences into index sequences using the given mappings.
def encode(sentences, labels, word2idx, tag2idx):
    """Encode tokens and tags as integer indices.

    Words missing from ``word2idx`` are mapped to the index of '<unk>'; tags
    missing from ``tag2idx`` are mapped to 0.

    Returns:
        ``(encoded_sentences, encoded_labels)``: two lists of lists of ints
        with the same nesting as the inputs.
    """
    def word_id(word):
        return word2idx.get(word, word2idx['<unk>'])

    def tag_id(tag):
        return tag2idx.get(tag, 0)

    encoded_sentences = [list(map(word_id, sentence)) for sentence in sentences]
    encoded_labels = [list(map(tag_id, label)) for label in labels]
    return encoded_sentences, encoded_labels

# Dataset class
class NERDataset(Dataset):
    """Dataset over already-encoded sentences and their tag sequences.

    ``sentences[i]`` and ``labels[i]`` are lists of integer indices describing
    the same sentence.
    """

    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return sentence ``idx`` as a ``(token_ids, tag_ids)`` pair of long tensors."""
        return (
            torch.tensor(self.sentences[idx], dtype=torch.long),
            torch.tensor(self.labels[idx], dtype=torch.long),
        )

    # Declared static because it takes no ``self``: without the decorator the
    # function only works when accessed through the class, and calling it on an
    # instance would pass the instance as ``batch``.
    @staticmethod
    def collate_fn(batch):
        """Pad a list of ``(token_ids, tag_ids)`` pairs to a common length.

        Reads the module-level ``word2idx`` and ``tag2idx`` mappings at call
        time. Sentences are padded with the index of '<pad>' and labels with
        the index of 'O', so padded label positions cannot be told apart from
        real 'O' tokens by looking at the labels alone; compare the padded
        sentence tensor against the '<pad>' index to locate padding.

        Returns:
            ``(sentences_padded, labels_padded)``, both batch-first tensors.
        """
        sentences, labels = zip(*batch)
        sentences_padded = pad_sequence(sentences, batch_first=True, padding_value=word2idx['<pad>'])
        labels_padded = pad_sequence(labels, batch_first=True, padding_value=tag2idx['O'])
        return sentences_padded, labels_padded

# Model definition
class BiLSTM(nn.Module):
    """Embedding -> (bi)directional LSTM -> linear layer producing per-token scores.

    Args:
        vocab_size: Number of rows of the embedding table.
        embedding_dim: Size of each word embedding.
        hidden_dim: Hidden size of the LSTM, per direction.
        output_dim: Number of output classes per token.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability applied to the embeddings, between LSTM
            layers, and to the LSTM outputs.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers, bidirectional, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            # nn.LSTM applies dropout only between stacked layers and warns
            # when it is non-zero for a single-layer network.
            dropout=dropout if num_layers > 1 else 0.0,
            # forward() receives batch-first input. Without this flag the LSTM
            # would treat the batch axis as time and recur across unrelated
            # sentences instead of along each sentence.
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Score every token of every sentence.

        Args:
            text: Long tensor of token indices, ``[batch_size, seq_len]``.

        Returns:
            Float tensor of unnormalised class scores,
            ``[batch_size, seq_len, output_dim]``.
        """
        embedded = self.dropout(self.embedding(text))  # [batch_size, seq_len, embedding_dim]
        outputs, _ = self.lstm(embedded)  # [batch_size, seq_len, hidden_dim * num_directions]
        outputs = self.dropout(outputs)
        predictions = self.fc(outputs)  # [batch_size, seq_len, output_dim]
        return predictions

# Hyperparameters
EMBEDDING_DIM = 100  # size of each word embedding
HIDDEN_DIM = 128  # LSTM hidden size
OUTPUT_DIM = 9  # adjust to the actual number of tags
NUM_LAYERS = 2  # number of stacked LSTM layers
BIDIRECTIONAL = True
DROPOUT = 0.5
BATCH_SIZE = args.batch_size  # from --batch-size
N_EPOCHS = args.epochs  # from --epochs

# Load the three dataset splits.
# NOTE(review): the '....../' prefix looks like a placeholder — point these
# paths at the real dataset location before running.
train_sentences, train_labels = load_data('....../conll2003/train.txt')
valid_sentences, valid_labels = load_data('....../conll2003/valid.txt')
test_sentences, test_labels = load_data('....../conll2003/test.txt')

# Build the vocabulary (from the training split only) and the tag mapping.
# Words that only occur in the validation/test splits are therefore encoded
# as '<unk>', and any tag missing from tag2idx is encoded as 0 by encode().
word2idx = build_vocab(train_sentences)
tag2idx = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}



# Encode words and tags as integer indices.
encoded_train_sentences, encoded_train_labels = encode(train_sentences, train_labels, word2idx, tag2idx)
encoded_valid_sentences, encoded_valid_labels = encode(valid_sentences, valid_labels, word2idx, tag2idx)
encoded_test_sentences, encoded_test_labels = encode(test_sentences, test_labels, word2idx, tag2idx)

# Wrap each encoded split in a Dataset.
train_dataset = NERDataset(encoded_train_sentences, encoded_train_labels)
valid_dataset = NERDataset(encoded_valid_sentences, encoded_valid_labels)
test_dataset = NERDataset(encoded_test_sentences, encoded_test_labels)
# Create the data loaders; collate_fn pads every batch to its longest sentence.
# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=NERDataset.collate_fn, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=NERDataset.collate_fn, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=NERDataset.collate_fn, num_workers=4)

# Instantiate the model; the embedding table has one row per vocabulary entry.
model = BiLSTM(len(word2idx), EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, NUM_LAYERS, BIDIRECTIONAL, DROPOUT)

# Count the model's parameters (the printed label is Chinese for "total parameter count").
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params}")

# Wrap the model in DataParallel when more than one GPU is visible, then move
# it to the selected device.
if n_gpu > 1:
    model = nn.DataParallel(model)
model.to(device)

# Optimizer selection.
# The name is normalised once so that every supported optimizer is matched
# case-insensitively; the original spellings 'adaBelief' and 'ALTO' keep working.
optimizer_name = args.optimizer.lower()
if optimizer_name == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=1e-5)
elif optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-5)
elif optimizer_name == 'adamw':
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=1e-5)
elif optimizer_name == 'lamb':
    # Project-local factory; it receives the model rather than its parameters.
    optimizer = create_lamb_optimizer(model,
                                      lr=args.lr,
                                      weight_decay=1e-5)
elif optimizer_name == 'adabelief':
    optimizer = AdaBelief(model.parameters(),
                          lr=args.lr,
                          betas=(0.9, 0.999))
elif optimizer_name == 'alto':
    # Project-local factory; it receives the model rather than its parameters.
    optimizer = create_ALTO_optimizer(model,
                                      lr=args.lr,
                                      betas=(0.7, 0.9, 0.99),
                                      weight_decay=1e-5)
else:
    raise ValueError(
        f"Unsupported optimizer {args.optimizer!r}. "
        "Choose one of: 'sgd', 'adam', 'adamw', 'lamb', 'adaBelief', 'ALTO'."
    )

# Multiply the learning rate by 0.1 every 100 scheduler steps (one step per epoch).
scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
# Cross-entropy over the per-token class scores.
criterion = nn.CrossEntropyLoss()

# Train the model; after every epoch validate it and rewrite the JSON log.
#
# Padding is located through the input ids (the index of '<pad>'), not through
# the labels: collate_fn pads labels with the index of 'O', which is 0, so a
# `labels != 0` test would also discard every real 'O' token, and an unmasked
# loss would count every padded position as a correct/incorrect 'O'.
pad_idx = word2idx['<pad>']

# open(..., "w") does not create missing directories, so make sure the log
# directory exists before the first epoch tries to write to it.
log_name = os.path.join("log", args.logname)
os.makedirs(os.path.dirname(log_name), exist_ok=True)

json_data = []
for epoch in range(N_EPOCHS):
    # ---- training phase ----
    model.train()
    train_loss = 0
    for sentences, labels in train_loader:
        sentences, labels = sentences.to(device), labels.to(device)

        optimizer.zero_grad()
        predictions = model(sentences)
        # Flatten to one row per token position.
        predictions = predictions.reshape(-1, predictions.shape[-1])
        labels = labels.reshape(-1)
        # Only real tokens contribute to the loss.
        real_tokens = sentences.reshape(-1) != pad_idx
        loss = criterion(predictions[real_tokens], labels[real_tokens])
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean training loss per batch (guarded against an empty loader).
    train_loss /= max(len(train_loader), 1)

    scheduler.step()

    # ---- validation phase ----
    model.eval()
    valid_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for sentences, labels in valid_loader:
            sentences, labels = sentences.to(device), labels.to(device)

            predictions = model(sentences)
            predictions = predictions.reshape(-1, predictions.shape[-1])
            labels = labels.reshape(-1)

            # Keep only the non-padded positions for both loss and accuracy.
            real_tokens = sentences.reshape(-1) != pad_idx
            predictions = predictions[real_tokens]
            labels = labels[real_tokens]

            loss = criterion(predictions, labels)
            valid_loss += loss.item()

            predicted = predictions.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    valid_loss /= max(len(valid_loader), 1)
    # Token-level accuracy over all real tokens; 0.0 if there were none.
    val_accuracy = correct / total if total else 0.0


    print(f"Epoch {epoch+1}/{N_EPOCHS}, Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    log_metrics = {
        'train_loss': train_loss,
        'valid_loss': valid_loss,
        'val_accuracy': val_accuracy
    }
    json_data.append(log_metrics)

    # Rewrite the whole log each epoch so an interrupted run still leaves a
    # complete, valid JSON file.
    with open(log_name, "w") as json_file:
        json.dump(json_data, json_file, indent=4)

print("Train finished")

# Evaluate the trained model on the test split and save the metrics.
model.eval()
all_predictions = []
all_labels = []

# Padding is located through the input ids: labels are padded with the index
# of 'O' (0), so masking on `labels != 0` would also drop every real 'O' token.
pad_idx = word2idx['<pad>']

with torch.no_grad():
    for sentences, labels in test_loader:
        sentences = sentences.to(device)
        labels = labels.to(device)

        predictions = model(sentences)
        predictions = predictions.reshape(-1, predictions.shape[-1])
        predicted_labels = predictions.argmax(dim=1)

        # One-dimensional mask selecting the non-padded positions.
        real_tokens = sentences.reshape(-1) != pad_idx

        predicted_labels = predicted_labels[real_tokens]
        labels = labels.reshape(-1)[real_tokens]

        all_predictions.extend(predicted_labels.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

# Token-level metrics; precision/recall/F1 are macro-averaged over the classes.
accuracy = accuracy_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions, average='macro', zero_division=0)
recall = recall_score(all_labels, all_predictions, average='macro', zero_division=0)
f1 = f1_score(all_labels, all_predictions, average='macro', zero_division=0)

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1 Score: {f1:.4f}")

metrics = {
    'accuracy': accuracy,
    'precision': precision,
    'recall': recall,
    'f1_score': f1
}

result_name = os.path.join("log", "testLog_" + args.logname)
# open(..., 'w') does not create missing directories.
os.makedirs(os.path.dirname(result_name), exist_ok=True)
with open(result_name, 'w') as f:
    json.dump(metrics, f, indent=4)

# Report the file that was actually written (not the training log name).
print(f"Metrics saved to {result_name}")